{-# LANGUAGE OverloadedStrings #-}
module Labeler 
    ( ModelData(..)
    , Config(..)
    , train
    , predict
    )
    
where

import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import Data.List (foldl',tails)
import Data.Maybe (fromMaybe)
import Helper.ListZipper
import qualified Perceptron.Sequence as P
import Perceptron.Sequence (Options(..))
import CorpusReader (Token)
import Helper.Utils (splitWith,uniq)
import Text.Printf
import Helper.Atom
import Control.Monad.RWS
import Features (inputFeatures,features,maybeFeatures,outputFeatures
                ,indexFeatures)
import qualified Data.Array as A
import qualified Data.Vector.Unboxed as V
import qualified Data.Binary as Binary
import qualified Helper.Text as Text
import Helper.Text(Txt)
import Data.Char
import Data.Maybe (catMaybes)
import Config 


-- | A trained sequence model bundled with the configuration it belongs to.
data ModelData = ModelData { model :: P.Model
                           -- ^ The trained model ('P.Model').
                           , config :: Config
                           -- ^ Configuration; as built by 'train', its
                           -- 'atomTable' is the table 'predict' runs against.
                           } 
-- Serialised field by field: the model first, then the config.  'get' must
-- read them back in exactly that order.
instance Binary.Binary ModelData where
    get = do
        storedModel  <- Binary.get
        storedConfig <- Binary.get
        return (ModelData storedModel storedConfig)
    put (ModelData storedModel storedConfig) = do
        Binary.put storedModel
        Binary.put storedConfig




--  Main exported functions 
-- | Label sentences with a trained model.
--
-- Each sentence is featurised with 'maybeFeatures', using the feature bounds
-- stored in the model's options together with the model's config.  It is
-- then decoded with 'P.decode' and the label atoms are turned back into text
-- by 'predict''.  Every atom operation runs against the atom table saved in
-- the model's config; the table that 'runAtoms' hands back next to the result
-- is dropped.
predict :: ModelData -> [[ListZipper Token]] -> [[Txt]]
predict m testdat = fst (runAtoms labelAll (atomTable conf))
  where
    conf       = config m
    featBounds = oFeatBounds (P.options (model m))
    labelAll   = mapM (\sentence -> mapM (maybeFeatures featBounds conf) sentence
                                      >>= predict' (P.decode (model m)))
                      testdat

-- | Train a model on labelled sentences.
--
-- Runs 'run' starting from @empty@ (presumably an empty atom table).
-- @heldout@ is handed to 'run' as its test set.  The labels 'run' predicts for
-- it and the final atom table are both discarded here; the returned
-- 'ModelData' carries the atom table that 'run' snapshots into its config.
train :: Config 
      -> [([ListZipper Token],[Txt])]
      -> [([ListZipper Token],[Txt])]
      -> ModelData
train conf traindat heldout = trained
  where
    ((trained, _heldoutLabels), _finalAtoms) =
        runAtoms (run conf traindat heldout) empty

-- Implementation
-- | Integer code of an input feature: the element type of the feature
-- vectors that 'mkfs' builds for each token.
type F = Int
-- | Integer code of an output label; 'mkfs' obtains these with 'toAtom'.
type Tag = Int
-- | Build the tag dictionary from the atomised training set.
--
-- For each token, its /index feature/ is the first element of its feature
-- vector that belongs to @indexFeatureSet@.  Tokens without one are skipped.
-- The result maps every index feature seen at least @wmin@ times to the
-- ascending list of tags it occurred with.
--
-- Each token is paired with its own tag /before/ tokens lacking an index
-- feature are dropped.  Filtering the features first and zipping them with
-- the flattened tag list afterwards would shift every later tag onto the
-- wrong token.
tagDictionary ::  IntSet.IntSet 
              -> Int 
              -> [([V.Vector F], [Tag])] 
              -> IntMap.IntMap [Tag]
tagDictionary indexFeatureSet wmin trainset =
    let pairs = [ (w, t)
                | (tokens, tags) <- trainset
                , (token, t)     <- zip tokens tags
                , Just w         <- [V.find (`IntSet.member` indexFeatureSet) token]
                ]
        -- How often each index feature occurs in the training data.
        countWs = IntMap.fromListWith (+) [ (w, 1 :: Int) | (w, _) <- pairs ]
        dict =   IntMap.map Set.toList
               . IntMap.fromListWith Set.union
               $ [ (w, Set.singleton t)
                 | (w, t) <- pairs
                 , IntMap.findWithDefault 0 w countWs >= wmin ]
    -- Comparing the map with itself walks every key and tag list, so the
    -- whole dictionary is evaluated here instead of being left as thunks.
    in (dict == dict) `seq` dict

-- | Replace every label that occurs fewer than @lim@ times across the whole
-- data set with @"UNDETERMINED"@.  Inputs and label-sequence lengths are left
-- unchanged.
pruneLabels :: Int -> [(x,[Txt])] -> [(x,[Txt])]
pruneLabels lim xys = map relabel xys
  where
    -- Total number of occurrences of each label in the data set.
    counts = Map.fromListWith (+) [ (label, 1) | (_, labels) <- xys, label <- labels ]
    relabel (input, labels) = (input, map replaceRare labels)
    replaceRare label
      | counts Map.! label < lim = "UNDETERMINED"
      | otherwise                = label

-- | Train a model on @trainset_in@ and label @testset_in@ (the held-out data)
-- with it.
--
-- Returns the trained 'ModelData' and the predicted labels for each held-out
-- sentence.  The returned config has its 'atomTable' replaced by the table
-- that exists once all features and labels are atomised.
--
-- NOTE(review): the monadic steps below allocate atoms ('toAtom', the
-- feature functions, 'mkfs').  Atom IDs presumably depend on the order of
-- these calls, so keep the steps in this order.
run :: (Functor m, MonadAtoms m) =>
       Config
    ->  [([ListZipper Token], [Txt])]
    ->  [([ListZipper Token], [Txt])]
    -> m (ModelData, [[Txt]])
run conf trainset_in testset_in = do
  -- Label pruning (see pruneLabels) is currently disabled; the commented-out
  -- binding on the next line is what would enable it.
  let --trainset_in = pruneLabels (minLabelFreq conf) trainset_in_full
      ys = uniq . concat . map snd $ trainset_in :: [Txt]
  ys' <- mapM toAtom ys
  outm <- mkOutputFeatureAtoms . map snd $ trainset_in 
  -- Size of the hashed feature space: the output-feature count plus either
  -- flagHashMaxSize or, when that is Nothing, an estimate computed from the
  -- training inputs.  The bounds are only set when flagHash is on.
  let size = outputFeatureCount outm + 
             maybe (estimateFeatureCount conf . map fst $ trainset_in)
                   id
                   (flagHashMaxSize . flags $ conf)
      bounds = if flagHash . flags $ conf 
               then Just (0,size)
               else Nothing
  -- Training sentences use 'features', held-out sentences 'maybeFeatures'.
  trainset <- mapM (mkfs $ features bounds conf) trainset_in
  testset <- mapM (mkfs $ maybeFeatures bounds conf) testset_in 
  -- Snapshot of the atom table; it goes into the returned config and is the
  -- table 'predict' later runs against.
  tab <- table
  let indexFeatureSet = indexFeatures tab
      conf' = conf {atomTable = tab }
      opts = Options { oYMap = outm
                     , oIndexSet =  indexFeatureSet
                     , oYDict = tagDictionary indexFeatureSet 
                                     (flagMinFeatCount . flags $ conf') trainset
                     , oYs   = ys'
                     , oBeam = flagBeam . flags $ conf
                     , oRate = flagRate . flags $ conf
                     , oEpochs = flagIter . flags $ conf
                     , oFeatBounds = bounds
                     }
      -- The held-out set and formatEval are passed to P.train, presumably
      -- for per-epoch progress reporting.
      m = P.train opts testset formatEval trainset
  -- Decode each held-out sentence from its feature vectors alone (fst) and
  -- map the label atoms back to text.
  ps <- mapM (predict' (P.decode m . fst)) testset
  return $ (ModelData { model = m , config = conf' }
           ,ps)

-- | Decode one input with @dec@ and turn each resulting label atom back into
-- text with 'fromAtom'.
predict' :: (MonadAtoms m) =>
            (t -> [Int]) -> t -> m [Txt]
predict' dec x = mapM fromAtom (dec x)

-- | Atomise all labels, adjacent label pairs and their output features, and
-- assemble the 'P.YMap' triple:
--
--   * the atomised features of @outputFeatures []@;
--   * an array from each label atom to its unigram output features;
--   * a 2-D array from each adjacent pair of label atoms to its bigram
--     output features.
--
-- Both arrays span the range from the smallest to the largest label atom.
-- With no labels at all, 'minimum' fails once an array is forced.
mkOutputFeatureAtoms :: (MonadAtoms m) => [[Txt]] -> m P.YMap
mkOutputFeatureAtoms labelSeqs = do
  let singleLabels = [ [label] | label <- uniq (concat labelSeqs) ]
      labelPairs   = uniq [ [a, b] | labels <- labelSeqs, (a:b:_) <- tails labels ]
  -- Keep the toAtom calls in this exact sequence: labels, label pairs,
  -- unigram features, bigram features, then the empty-context features.
  singleIds   <- mapM (mapM toAtom) singleLabels
  pairIds     <- mapM (mapM toAtom) labelPairs
  singleFeats <- mapM (mapM toAtom . outputFeatures) singleLabels
  pairFeats   <- mapM (mapM toAtom . outputFeatures) labelPairs
  emptyFeats  <- mapM toAtom (outputFeatures [])
  let labelIds = map head singleIds
      lo = minimum labelIds
      hi = maximum labelIds
      -- Every entry of pairIds has exactly two elements by construction.
      asPair [a, b] = (a, b)
      unigramMap = A.accumArray (V.++) V.empty (lo, hi)
                     (zip labelIds (map V.fromList singleFeats))
      bigramMap  = A.accumArray (V.++) V.empty ((lo, lo), (hi, hi))
                     (zip (map asPair pairIds) (map V.fromList pairFeats))
  return (V.fromList emptyFeats, unigramMap, bigramMap)

-- | Largest output-feature atom stored anywhere in a 'P.YMap': in the
-- empty-context features, the unigram array or the bigram array.
--
-- Returns 0 when the map holds no features at all, instead of failing on
-- 'maximum' of an empty list.
outputFeatureCount :: P.YMap -> Int
outputFeatureCount (zero,uni,bi) =
    case allFeatures of
      [] -> 0
      _  -> maximum allFeatures
  where
    allFeatures =  V.toList zero
                ++ concatMap V.toList (A.elems uni)
                ++ concatMap V.toList (A.elems bi)
                             
-- | Turn one (sentence, labels) pair into per-token feature vectors (via @f@)
-- and label atoms (via 'toAtom').
--
-- NOTE(review): features are atomised before labels.  If 'toAtom' numbers
-- atoms in first-seen order, swapping these steps would change the IDs.
mkfs :: (MonadAtoms m) => 
        (ListZipper Token -> m (V.Vector F))
     ->   ([ListZipper Token], [Txt]) 
     ->   m ([V.Vector F], [Tag])
mkfs f (x,y) = do
  fs <- mapM f x
  -- Parses as (fs == fs) `seq` return ().  Comparing the list with itself
  -- walks every vector, so the features are forced when this action is
  -- evaluated instead of piling up as thunks.
  fs == fs `seq` return ()
  y' <- mapM toAtom y
  -- Same forcing trick for the label atoms.
  y' == y' `seq` return ()
  return $ (fs,y')

-- | Estimate how many distinct input features a corpus produces.
--
-- The first @flagHashSample@ sentences are sampled (all of them if there are
-- fewer).  Their 'inputFeatures' are passed through 'uniq' (presumably
-- de-duplication) and counted.  The count is scaled by
-- @length xs `div` sampleSize@ (integer division).
--
-- The sample size is clamped to at least one sentence.  A non-positive
-- @flagHashSample@ or an empty corpus therefore cannot cause a division by
-- zero; an empty corpus yields 0.
estimateFeatureCount :: Config -> [[ListZipper Token]] -> Int
estimateFeatureCount conf xs =
    let len = length xs
        size = max 1 (min len . flagHashSample . flags $ conf)
        factor = len `div` size
        sampled =  length
                 . uniq
                 . concatMap (concatMap (inputFeatures conf))
                 . take size
                 $ xs
    in factor * sampled

-- | Progress line for training.  Iteration 0 yields the column header; any
-- other iteration yields the iteration number and the 'eval' accuracy of the
-- two result sets (training first, held-out second).
formatEval :: P.Eval 
formatEval iter trainRes heldoutRes
  | iter == 0 = printf "%10s %10s %10s" ("Iter"::String)
                                        ("Train"::String)
                                        ("Heldout"::String)
  | otherwise = printf "%10d %10.4f %10.4f" iter (eval trainRes) (eval heldoutRes)
    

-- | Position-wise accuracy: the number of aligned positions where the two
-- sequences of a pair hold equal elements, divided by the total length of
-- the first sequences.  Extra elements in a longer second sequence are
-- ignored.  An empty input gives 0/0, i.e. NaN.
eval :: Eq a => [([a],[a])] -> Double
eval ys = matched / fromIntegral tokens
  where
    matchesIn (ref, hyp) = length (filter id (zipWith (==) ref hyp))
    matched = foldl' (\acc pair -> acc + fromIntegral (matchesIn pair)) 0 ys
    tokens  = sum [ length ref | (ref, _) <- ys ]
